library(tidyverse)
library(forcats)
library(here)
library(zoo)

# load index of information about the video. `fname` (the video's
# filename_base) must be defined before this script is run.
index <-
  read_csv(str_c(here(), "/data/index_files_2022.csv")) %>%
  filter(filename_base == .env$fname)

# The rest of the script reads single values from `index` (file names,
# seconds_per_cycle) and uses them in if() conditions and file paths, so
# anything other than exactly one matching row is an error.
if (nrow(index) != 1) {
  stop(
    "expected exactly one row for '", fname,
    "' in index_files_2022.csv, found ", nrow(index),
    call. = FALSE
  )
}

# image folder name = the part of fname before the first "_"
folder <- str_split(fname, "_")[[1]][1]

# load index of stimuli for each frame
stimulus_levels <- c("normoxia", "hypoxia", "sham hypoxia", "NaCN", "FCCP")

stimuli <- read_csv(str_c(here(), "/data/stimuli/", index$filename_stimuli))

# ordered() silently turns labels outside `stimulus_levels` into NA (and NA
# stimuli are excluded from the per-stimulus windows later on), so report them
unknown_stimuli <- setdiff(stimuli$stimulus, c(stimulus_levels, NA))
if (length(unknown_stimuli) > 0) {
  warning(
    "unrecognised stimulus labels (set to NA): ",
    str_c(unknown_stimuli, collapse = ", "),
    call. = FALSE
  )
}

stimuli <-
  stimuli %>%
  mutate(stimulus = ordered(stimulus, levels = stimulus_levels)) %>%
  select(filename_base, frame, stimulus, baseline)
  
# Load one ROI measurement table exported from ImageJ and tidy it.
#
# fname  file name of the exported CSV; only the Label, Area, Mean, Slice and
#        Frame columns are read, any others are ignored
# type   value stored in the `roi_type` column (e.g. "glomus", "glia", "vasc")
# dir    directory containing `fname`; defaults to
#        <here()>/../../data/images/<folder>, using the script-level `folder`
#
# Returns a tibble with columns area, F (ImageJ's "Mean"), slice, frame,
# roi (the number in the second ":"-separated field of Label) and roi_type.
load_rois <- function(fname, type = "glomus",
                      dir = str_c(here(), "/../../data/images/", folder)) {
  read_csv(
    str_c(dir, "/", fname),
    # spelled out in full: `col_type` only worked through partial matching
    col_types = cols_only(
      Label = col_character(),
      Area = col_double(),
      Mean = col_double(),
      Slice = col_double(),
      Frame = col_double()
    )
  ) %>%
    rename_with(tolower) %>%
    rename(`F` = mean) %>%
    mutate(
      # pull ROI number from long imagej name; map_dbl (unlike map + unlist,
      # which yields NULL and drops the column) keeps a numeric column even
      # when the table has no rows
      roi = map_dbl(str_split(label, ":"), ~ as.numeric(.x[2])),
      roi_type = type
    ) %>%
    select(-label)
}

# Glomus ROIs are always loaded. Glia and vasculature ROI tables are appended,
# in that order, only when the index lists a file for them (not NA).
rois <- load_rois(index$filename_rois, type = "glomus")

if (!is.na(index$filename_rois_glia)) {
  rois_glia <- load_rois(index$filename_rois_glia, type = "glia")
  rois <- bind_rows(rois, rois_glia)
}

if (!is.na(index$filename_rois_vasc)) {
  rois_vasc <- load_rois(index$filename_rois_vasc, type = "vasc")
  rois <- bind_rows(rois, rois_vasc)
}

# Per-ROI traces (one group per roi_type + roi):
#   F_avg      centred 3-point rolling mean of raw fluorescence F
#   F_b        F_avg kept only on frames flagged `baseline`, linearly
#              interpolated across the other frames (rule = 2: frames outside
#              the baseline range take the nearest baseline value)
#   F_b/F_b0   F_b normalised to the mean of its first four rows
#   F/F0       raw F normalised to the mean of its first four rows
#   F/F0_bs    baseline-subtracted delta F/F0 = F/F0 - F_b/F_b0
# NOTE(review): "first four rows" equals "first four frames" only if rows are
# in frame order after the join - confirm the stimuli file is sorted by frame.
traces <- stimuli %>%
  left_join(rois, by = "frame") %>%
  mutate(time_seconds = (frame - 1) * index$seconds_per_cycle) %>%
  mutate(time_minutes = time_seconds / 60) %>%
  group_by(roi_type, roi) %>%
  mutate(
    F_avg = rollmean(`F`, 3, align = "center", fill = NA),
    F_b = na.approx(ifelse(baseline == TRUE, F_avg, NA), rule = 2),
    `F_b/F_b0` = F_b / mean(head(F_b, 4), na.rm = TRUE),
    `F/F0` = `F` / mean(head(`F`, 4), na.rm = TRUE),
    `F/F0_bs` = `F/F0` - `F_b/F_b0`
  ) %>%
  ungroup() %>%
  select(
    filename_base, slice, roi, frame, time_seconds, time_minutes, stimulus,
    everything()
  ) %>%
  arrange(filename_base, roi)

write_csv(traces, str_c(here(), "/data/traces/", fname, "_traces.csv"))
write_rds(traces, str_c(here(), "/data/traces/", fname, "_traces.rds"))

# One row per ROI: the first raw F value (F0), the last interpolated-baseline
# value (Fend), and their absolute and relative (to F0) difference.
traces_summary <- traces %>%
  group_by(roi_type, roi) %>%
  summarise(
    F0 = first(`F`),
    Fend = last(F_b),
    deltaF0Fend = F0 - Fend,
    relativeF0Fend = deltaF0Fend / F0
  ) %>%
  ungroup()

write_csv(
  traces_summary,
  str_c(here(), "/data/summary/", fname, "_traces_summary.csv")
)
write_rds(
  traces_summary,
  str_c(here(), "/data/summary/", fname, "_traces_summary.rds")
)

# Select the baseline stimulus (usually "normoxia"): the most frequent
# non-NA stimulus label in the traces. Exactly one value is returned.
# top_n(1) kept every tied label, which broke the later comparison against
# this value; on a tie this picks the label that count() lists first.
stimulus_baseline <-
  traces %>%
  filter(!is.na(stimulus)) %>%
  count(stimulus) %>%
  arrange(desc(n)) %>%
  slice_head(n = 1) %>%
  pull(stimulus)

# Extract the rows of `df` around one stimulus: every row whose `stimulus`
# equals `stimulus_current`, up to `n_before` rows before the first of them
# and up to `n_after` rows after the last of them.
#
# df                 data frame with `stimulus` and `frame` columns (assumed
#                    to be one row per frame, in frame order)
# stimulus_current   the stimulus value to extract
# n_before, n_after  maximum number of context rows on each side (>= 0); the
#                    context is clipped at the first/last row of `df`
#
# Returns the selected rows with three added columns:
#   stimulus_group  stimulus_current, repeated
#   window_type     "pre-stimulus", "stimulus" or "post-stimulus"
#   frame_relative  frame number counted from 1 at the start of the window
pull_window <- function(df, stimulus_current, n_before = 2, n_after = 6) {
  stopifnot(n_before >= 0, n_after >= 0)

  indices_stim <- which(df$stimulus == stimulus_current)
  if (length(indices_stim) == 0) {
    stop("stimulus '", stimulus_current, "' not found in df", call. = FALSE)
  }
  first_stim <- indices_stim[1]
  last_stim <- indices_stim[length(indices_stim)]

  # Clip with tail()/min() rather than seq(a, b, 1): seq() errors when b < a,
  # e.g. when the stimulus starts on the first row or ends on the last row.
  before <- tail(seq_len(first_stim - 1), n_before)
  after <- last_stim + seq_len(min(n_after, nrow(df) - last_stim))

  window_labels <- c(
    rep("pre-stimulus", length(before)),
    rep("stimulus", length(indices_stim)),
    rep("post-stimulus", length(after))
  )

  df %>%
    slice(c(before, indices_stim, after)) %>%
    mutate(
      stimulus_group = stimulus_current,
      window_type = window_labels,
      frame_relative = frame - min(frame) + 1
    )
}

# Cut the per-frame stimulus table into one window per non-baseline stimulus
# (up to 4 frames before through 6 frames after), add window-relative times,
# then attach every ROI's trace values for those frames.
traces_separated <-
  stimuli %>%
  # %in% (rather than !=) also copes with a multi-valued stimulus_baseline;
  # NA stimuli are dropped explicitly because pull_window cannot match them
  filter(!is.na(stimulus), !stimulus %in% stimulus_baseline) %>%
  pull(stimulus) %>%
  unique() %>%
  map_dfr(~ pull_window(stimuli, .x, n_before = 4, n_after = 6)) %>%
  mutate(
    time_relative_seconds = (frame_relative - 1) * index$seconds_per_cycle,
    time_relative_minutes = time_relative_seconds / 60
  ) %>%
  left_join(
    traces %>% select(-baseline),
    by = c("filename_base", "frame", "stimulus")
  ) %>%
  select(
    filename_base, slice, roi_type, roi, frame, time_seconds, time_minutes,
    stimulus_group, window_type, stimulus, everything()
  ) %>%
  arrange(filename_base, roi)

write_csv(
  traces_separated,
  str_c(here(), "/data/traces/", fname, "_traces_separated.csv")
)
write_rds(
  traces_separated,
  str_c(here(), "/data/traces/", fname, "_traces_separated.rds")
)

# Index helpers for summarising traces_separated. Each takes the `window_type`
# column of one roi_type/roi/stimulus_group group and returns row positions
# within that group (assumes the group's rows are in frame order, as laid out
# by pull_window).

# Stimulus rows except the first, plus the `n_post` rows that directly follow
# the last stimulus row (positions past the end of the group are dropped).
stim_window_long <- function(window_type, n_post = 2) {
  stim <- which(window_type == "stimulus")
  if (length(stim) == 0) {
    return(integer(0))
  }
  idx <- c(stim, max(stim) + seq_len(n_post))
  idx <- idx[idx <= length(window_type)]
  idx[-1]
}

# Stimulus rows except the first and except the last `n_drop` stimulus rows.
stim_window_short <- function(window_type, n_drop = 3) {
  stim <- which(window_type == "stimulus")
  stim <- head(stim, max(0, length(stim) - n_drop))
  stim[-1]
}

# Whichever of `a` and `b` is larger in absolute value (`a` on a tie);
# NA when neither can be compared.
larger_abs <- function(a, b) {
  vals <- c(a, b)
  pick <- which.max(abs(vals))
  if (length(pick) == 0) NA_real_ else vals[pick]
}

# One row per ROI and stimulus: window timing plus mean and peak delta F/F0
# responses, each measured relative to the last pre-stimulus frame.
#
# NOTE(review): the previous index arithmetic
#   length(stim) + min(stim) + 1 / + 2
# selected the 2nd and 3rd frames after the stimulus, despite comments saying
# "+1"/"+2". Its "short" window, c(stim, length(stim) + min(stim) - 2),
# excluded nothing and only counted one stimulus frame twice. The windows
# below follow the stated intent; confirm the short-window definition
# against the calcium imaging protocol.
traces_separated_summary <-
  traces_separated %>%
  group_by(roi_type, roi, stimulus_group) %>%
  summarise(
    start_time_seconds =
      time_relative_seconds[min(which(window_type == "stimulus"))],

    # 2 frames after the last stimulus frame, i.e. the end of the long window
    stop_time_seconds =
      time_relative_seconds[max(which(window_type == "stimulus")) + 2],

    start_time_minutes = start_time_seconds / 60,

    stop_time_minutes = stop_time_seconds / 60,

    # stimulus mean and peak measured against the point right before the
    # stimulation window (NA when the window has no pre-stimulus frame)
    stim_base = last(
      `F/F0_bs`[window_type == "pre-stimulus"],
      default = NA_real_
    ),

    # stimulus mean: stimulus window excluding its first point, plus the
    # 2 points following the end of the stimulus
    stim_mean_raw = mean(
      `F/F0_bs`[stim_window_long(window_type)],
      na.rm = TRUE
    ),

    # stimulus mean short: stimulus window excluding its first point and its
    # last 3 cycles, to be comparable to the stimulus timing of the calcium
    # imaging experiments
    stim_mean_short_raw = mean(
      `F/F0_bs`[stim_window_short(window_type)],
      na.rm = TRUE
    ),

    stim_mean = stim_mean_raw - stim_base,

    stim_mean_short = stim_mean_short_raw - stim_base,

    # extremes over the long window (see stim_mean_raw)
    stim_peak_max_raw = max(
      `F/F0_bs`[stim_window_long(window_type)],
      na.rm = TRUE
    ),

    stim_peak_min_raw = min(
      `F/F0_bs`[stim_window_long(window_type)],
      na.rm = TRUE
    ),

    # signed peak: whichever extreme deviates more from the base
    stim_peak = larger_abs(
      stim_peak_min_raw - stim_base,
      stim_peak_max_raw - stim_base
    ),

    stim_peak_max = stim_peak_max_raw - stim_base,

    # extremes over the short window (see stim_mean_short_raw)
    stim_peak_max_short_raw = max(
      `F/F0_bs`[stim_window_short(window_type)],
      na.rm = TRUE
    ),

    stim_peak_min_short_raw = min(
      `F/F0_bs`[stim_window_short(window_type)],
      na.rm = TRUE
    ),

    stim_peak_short = larger_abs(
      stim_peak_min_short_raw - stim_base,
      stim_peak_max_short_raw - stim_base
    ),

    stim_peak_max_short = stim_peak_max_short_raw - stim_base,

    .groups = "drop"
  ) %>%
  left_join(traces_summary, by = c("roi_type", "roi"))

write_csv(
  traces_separated_summary,
  str_c(here(), "/data/summary/", fname, "_traces_separated_summary.csv")
)
write_rds(
  traces_separated_summary,
  str_c(here(), "/data/summary/", fname, "_traces_separated_summary.rds")
)
  
